{% include '_header.py.jinja' %}
{% from '_utils.py.jinja' import recursive_types, active_provider with context %}
# -- template types.py.jinja --
from typing import TypeVar

import httpx
from . import _types
from .utils import _NoneType


{% set depth = generator.config.recursive_type_depth %}

# TODO: filters with aggregates should have their own recursive fields
# TODO: cleanup whitespace control
# TODO: add an argument to signify that the last iteration should be skipped
{#
  recursive(base): call-block helper that renders its caller once per generated type name.

  The caller is invoked as caller(name, next, iteration):
    - if `recursive_types` is truthy: a single call with (base, base, ''), so the
      rendered type can refer to itself by name;
    - otherwise: `depth` calls. The first `name` is `base`, later ones are
      base + 'Recursive<i>'. `iteration` is 'Recursive<i + 1>' and `next` is
      base + iteration, except on the final pass where both are '' so the caller
      can leave out the reference to a deeper type.
#}
{% macro recursive(base) %}
    {% if recursive_types %}
{{ caller(base, base, '') }}
    {% else %}
    {%+ for i in range(depth) -%}
        {% if i == 0 %}
            {% set name = base %}
        {% else %}
            {% set name = base + 'Recursive%s' % i %}
        {% endif %}

        {% if i == depth - 1 %}
            {% set iteration = '' %}
        {% else %}
            {% set iteration = 'Recursive%s' % (i + 1) %}
        {% endif %}

        {% if iteration %}
            {% set next = base + iteration %}
        {% else %}
            {% set next = '' %}
        {% endif %}
{{ caller(name, next, iteration) }}
    {%- endfor %}
    {% endif %}
{% endmacro %}

{#
  render_type(type): emit the Python definition for a schema type object.

  Subtypes are rendered first (recursively). The definition then depends on `type.kind`:
    'alias'     -> plain assignment to `type.to`
    'union'     -> Union[...] of the quoted variant names
    'typeddict' -> functional TypedDict(...) call built from `type.fields`
    'enum'      -> StrEnum subclass built from `type.members`
  Any other kind is reported through raise_err().
#}
{% macro render_type(type) %}
{# It is important that we render subtypes first so that references can be resolved correctly #}
{% for subtype in type.subtypes %}
{{ render_type(subtype) }}
{% endfor %}
{% if type.kind == 'alias' %}
{{ type.name }} = {{ type.to }}
{% elif type.kind == 'union' %}
{{ type.name }} = Union[
    {% for variant in type.variants %}
    '{{ variant.name }}',
    {% endfor %}
]
{% elif type.kind == 'typeddict' %}
{# We use the old syntax for defined TypedDicts so that fields that shadow keywords #}
{# can be defined, e.g. 'not' #}
{{ type.name }} = TypedDict(
    '{{ type.name }}',
    {
        {% for field, subtype in type.fields.items() %}
        '{{ field }}': {{ type_as_string(subtype) }},
        {% endfor %}
    },
    total={{ type.total }}
)
{% elif type.kind == 'enum' %}
class {{ type.name }}(StrEnum):
    {% for name, value in type.members %}
    {{ name }} = "{{ value }}"
    {% endfor %}
{% else %}
{{ raise_err('Unhandled type kind: %s' % type.kind) }}
{% endif %}
{% endmacro %}

# aliases for definitions that live in the `_types` module, so that they can
# also be referenced through this module
SortMode = _types.SortMode
SortOrder = _types.SortOrder

MetricsFormat = _types.MetricsFormat

DatasourceOverride = _types.DatasourceOverride
HttpConfig = _types.HttpConfig


# types that can be serialized to json by our query builder
#
# when recursive types are enabled the list / dict members refer back to
# 'Serializable' itself, otherwise nested values are typed as Any
{% if recursive_types %}
Serializable = Union[
    None,
    bool,
    float,
    int,
    str,
    datetime.datetime,
    List['Serializable'],
    Dict[None, 'Serializable'],
    Dict[bool, 'Serializable'],
    Dict[float, 'Serializable'],
    Dict[int, 'Serializable'],
    Dict[str, 'Serializable'],
]
{% else %}
Serializable = Union[
    None,
    bool,
    float,
    int,
    str,
    datetime.datetime,
    List[Any],
    Dict[None, Any],
    Dict[bool, Any],
    Dict[float, Any],
    Dict[int, Any],
    Dict[str, Any],
]
{% endif %}


# String field filter; every key is optional (total=False).
# 'not' is only emitted when the recursion helper supplies a follow-up type,
# 'mode' is only emitted for the postgresql, cockroachdb and mongodb providers
# and 'search' only for the postgresql and mysql providers.
{% call(name, next, iteration) recursive('StringFilter') %}
{{ name }} = TypedDict(
    '{{ name }}',
    {
        'equals': str,
        'not_in': List[str],
        'lt': str,
        'lte': str,
        'gt': str,
        'gte': str,
        'contains': str,
        'startswith': str,
        'endswith': str,
        'in': List[str],
        {%+ if next != '' -%}
            'not': Union[str, '{{ next }}'],
        {% endif %}
        {%+ if active_provider in ['postgresql', 'cockroachdb', 'mongodb'] -%}
            'mode': SortMode,
        {% endif %}
        {%+ if active_provider in ['postgresql', 'mysql'] -%}
            'search': str,
        {% endif %}
    },
    total=False,
)
{% endcall %}


class StringWithAggregatesFilter(StringFilter, total=False):
    """StringFilter keys plus optional aggregate keys.

    `_count` takes an IntFilter, every other aggregate key takes a StringFilter.
    """
    _max: 'StringFilter'
    _min: 'StringFilter'
    _sum: 'StringFilter'
    _avg: 'StringFilter'
    _count: 'IntFilter'


# DateTime field filter; every key is optional (total=False).
# 'not' is only emitted when the recursion helper supplies a follow-up type.
{% call(name, next, iteration) recursive('DateTimeFilter') %}
{{ name }} = TypedDict(
    '{{ name }}',
    {
        'equals': datetime.datetime,
        'not_in': List[datetime.datetime],
        'lt': datetime.datetime,
        'lte': datetime.datetime,
        'gt': datetime.datetime,
        'gte': datetime.datetime,
        'in': List[datetime.datetime],
        {%+ if next != '' -%}
            'not': Union[datetime.datetime, '{{ next }}'],
        {% endif %}
    },
    total=False,
)
{% endcall %}


class DateTimeWithAggregatesFilter(DateTimeFilter, total=False):
    """DateTimeFilter keys plus optional aggregate keys.

    `_count` takes an IntFilter, every other aggregate key takes a DateTimeFilter.
    """
    _max: 'DateTimeFilter'
    _min: 'DateTimeFilter'
    _sum: 'DateTimeFilter'
    _avg: 'DateTimeFilter'
    _count: 'IntFilter'


# Boolean field filter; every key is optional (total=False).
# 'not' is only emitted when the recursion helper supplies a follow-up type.
{% call(name, next, iteration) recursive('BooleanFilter') %}
{{ name }} = TypedDict(
    '{{ name }}',
    {
        'equals': bool,
        {%+ if next != '' -%}
            'not': Union[bool, '{{ next }}'],
        {% endif %}
    },
    total=False,
)
{% endcall %}


class BooleanWithAggregatesFilter(BooleanFilter, total=False):
    """BooleanFilter keys plus optional aggregate keys.

    `_count` takes an IntFilter, every other aggregate key takes a BooleanFilter.
    """
    _max: 'BooleanFilter'
    _min: 'BooleanFilter'
    _sum: 'BooleanFilter'
    _avg: 'BooleanFilter'
    _count: 'IntFilter'


# Int field filter; every key is optional (total=False).
# 'not' is only emitted when the recursion helper supplies a follow-up type.
{% call(name, next, iteration) recursive('IntFilter') %}
{{ name }} = TypedDict(
    '{{ name }}',
    {
        'equals': int,
        'not_in': List[int],
        'lt': int,
        'lte': int,
        'gt': int,
        'gte': int,
        'in': List[int],
        {%+ if next != '' -%}
            'not': Union[int, '{{ next }}'],
        {% endif %}
    },
    total=False,
)
{% endcall %}


class IntWithAggregatesFilter(IntFilter, total=False):
    """IntFilter keys plus optional aggregate keys, each taking an IntFilter."""
    _max: 'IntFilter'
    _min: 'IntFilter'
    _sum: 'IntFilter'
    _avg: 'IntFilter'
    _count: 'IntFilter'


# BigInt fields reuse the Int filter definitions unchanged
BigIntFilter = IntFilter
BigIntWithAggregatesFilter = IntWithAggregatesFilter
# Float field filter; every key is optional (total=False).
# 'not' is only emitted when the recursion helper supplies a follow-up type.
{% call(name, next, iteration) recursive('FloatFilter') %}
{{ name }} = TypedDict(
    '{{ name }}',
    {
        'equals': float,
        'not_in': List[float],
        'lt': float,
        'lte': float,
        'gt': float,
        'gte': float,
        'in': List[float],
        {%+ if next != '' -%}
            'not': Union[float, '{{ next }}'],
        {% endif %}
    },
    total=False,
)
{% endcall %}


class FloatWithAggregatesFilter(FloatFilter, total=False):
    """FloatFilter keys plus optional aggregate keys.

    `_count` takes an IntFilter, every other aggregate key takes a FloatFilter.
    """
    _max: 'FloatFilter'
    _min: 'FloatFilter'
    _sum: 'FloatFilter'
    _avg: 'FloatFilter'
    _count: 'IntFilter'


# Bytes field filter; values are referenced as 'fields.Base64' and every key
# is optional (total=False).
# 'not' is only emitted when the recursion helper supplies a follow-up type.
{% call(name, next, iteration) recursive('BytesFilter') %}
{{ name }} = TypedDict(
    '{{ name }}',
    {
        'equals': 'fields.Base64',
        'in': List['fields.Base64'],
        'not_in': List['fields.Base64'],
        {%+ if next != '' -%}
            'not': Union['fields.Base64', '{{ next }}'],
        {% endif %}
    },
    total=False,
)
{% endcall %}


class BytesWithAggregatesFilter(BytesFilter, total=False):
    """BytesFilter keys plus optional aggregate keys.

    `_count` takes an IntFilter, every other aggregate key takes a BytesFilter.
    """
    _max: 'BytesFilter'
    _min: 'BytesFilter'
    _sum: 'BytesFilter'
    _avg: 'BytesFilter'
    _count: 'IntFilter'


# TODO: preview feature for improving JSON filtering
#
# Json field filter; both keys are optional (total=False). Unlike the other
# scalar filters, 'not' only accepts a plain Json value and no nested filter.
JsonFilter = TypedDict(
    'JsonFilter',
    {
        'equals': 'fields.Json',
        'not': 'fields.Json',
    },
    total=False,
)


class JsonWithAggregatesFilter(JsonFilter, total=False):
    """JsonFilter keys plus optional aggregate keys.

    `_count` takes an IntFilter, every other aggregate key takes a JsonFilter.
    """
    _max: 'JsonFilter'
    _min: 'JsonFilter'
    _sum: 'JsonFilter'
    _avg: 'JsonFilter'
    _count: 'IntFilter'


# Decimal field filter; every key is optional (total=False).
# 'not' is only emitted when the recursion helper supplies a follow-up type.
{% call(name, next, iteration) recursive('DecimalFilter') %}
{{ name }} = TypedDict(
    '{{ name }}',
    {
        'equals': decimal.Decimal,
        'not_in': List[decimal.Decimal],
        'lt': decimal.Decimal,
        'lte': decimal.Decimal,
        'gt': decimal.Decimal,
        'gte': decimal.Decimal,
        'in': List[decimal.Decimal],
        {%+ if next != '' -%}
            'not': Union[decimal.Decimal, '{{ next }}'],
        {% endif %}
    },
    total=False,
)
{% endcall %}


class DecimalWithAggregatesFilter(DecimalFilter, total=False):
    """DecimalFilter keys plus optional aggregate keys.

    `_count` takes an IntFilter, every other aggregate key takes a DecimalFilter.
    """
    # NOTE: the base must be DecimalFilter (not StringFilter) so that the
    # inherited comparison keys are typed as decimal.Decimal, matching every
    # other *WithAggregatesFilter class which extends its own base filter.
    _max: 'DecimalFilter'
    _min: 'DecimalFilter'
    _sum: 'DecimalFilter'
    _avg: 'DecimalFilter'
    _count: 'IntFilter'


# Atomic float operation: a dict with one required `set` key.
_FloatSetInput = TypedDict(
    '_FloatSetInput',
    {
        'set': float,
    },
)


# Atomic float operation: a dict with one required `divide` key.
_FloatDivideInput = TypedDict(
    '_FloatDivideInput',
    {
        'divide': float,
    },
)


# Atomic float operation: a dict with one required `multiply` key.
_FloatMultiplyInput = TypedDict(
    '_FloatMultiplyInput',
    {
        'multiply': float,
    },
)


# Atomic float operation: a dict with one required `increment` key.
_FloatIncrementInput = TypedDict(
    '_FloatIncrementInput',
    {
        'increment': float,
    },
)


# Atomic float operation: a dict with one required `decrement` key.
_FloatDecrementInput = TypedDict(
    '_FloatDecrementInput',
    {
        'decrement': float,
    },
)


# Atomic int operation: a dict with one required `set` key.
_IntSetInput = TypedDict(
    '_IntSetInput',
    {
        'set': int,
    },
)


# Atomic int operation: a dict with one required `divide` key.
_IntDivideInput = TypedDict(
    '_IntDivideInput',
    {
        'divide': int,
    },
)


# Atomic int operation: a dict with one required `multiply` key.
_IntMultiplyInput = TypedDict(
    '_IntMultiplyInput',
    {
        'multiply': int,
    },
)


# Atomic int operation: a dict with one required `increment` key.
_IntIncrementInput = TypedDict(
    '_IntIncrementInput',
    {
        'increment': int,
    },
)


# Atomic int operation: a dict with one required `decrement` key.
_IntDecrementInput = TypedDict(
    '_IntDecrementInput',
    {
        'decrement': int,
    },
)


# An atomic update is exactly one of the single-key operation dictionaries
# (set / divide / multiply / increment / decrement) defined above.
AtomicFloatInput = Union[
    _FloatSetInput,
    _FloatDivideInput,
    _FloatMultiplyInput,
    _FloatIncrementInput,
    _FloatDecrementInput,
]
AtomicIntInput = Union[
    _IntSetInput,
    _IntDivideInput,
    _IntMultiplyInput,
    _IntIncrementInput,
    _IntDecrementInput,
]
# BigInt fields reuse the Int atomic operations
AtomicBigIntInput = AtomicIntInput

# Filter and update inputs for scalar list fields, generated once for every
# (type, python_type) pair returned by get_list_types().
# A list filter is one of the single-key dictionaries below; a list update is
# either a plain list or a single-key `set` / `push` dictionary. The `push`
# variant is not generated when the active provider is cockroachdb.
{% for type, python_type in get_list_types() %}
class _{{ type }}ListFilterEqualsInput(TypedDict):
    equals: Optional[List[{{ python_type }}]]


class _{{ type }}ListFilterHasInput(TypedDict):
    has: {{ python_type }}


class _{{ type }}ListFilterHasEveryInput(TypedDict):
    has_every: List[{{ python_type }}]


class _{{ type }}ListFilterHasSomeInput(TypedDict):
    has_some: List[{{ python_type }}]


class _{{ type }}ListFilterIsEmptyInput(TypedDict):
    is_empty: bool


{{ type }}ListFilter = Union[
    _{{ type }}ListFilterHasInput,
    _{{ type }}ListFilterEqualsInput,
    _{{ type }}ListFilterHasSomeInput,
    _{{ type }}ListFilterIsEmptyInput,
    _{{ type }}ListFilterHasEveryInput,
]


class _{{ type }}ListUpdateSet(TypedDict):
    set: List[{{ python_type }}]


{% if active_provider != 'cockroachdb' %}
class _{{ type }}ListUpdatePush(TypedDict):
    push: List[{{ python_type }}]
{% endif %}


{{ type }}ListUpdate = Union[
    List[{{ python_type }}],
    _{{ type }}ListUpdateSet,
    {% if active_provider != 'cockroachdb' %}
    _{{ type }}ListUpdatePush,
    {% endif %}
]

{% endfor %}

{% for model in dmmf.datamodel.models %}
{% set model_schema = type_schema.get_model(model.name) %}
# {{ model.name }} types

class {{ model.name }}OptionalCreateInput(TypedDict, total=False):
    """Optional arguments to the {{ model.name }} create method"""
    {# a field is listed here when it is not required on create or when it is read-only #}
    {% for field in model.all_fields %}
        {%- if not field.required_on_create or field.is_read_only -%}
            {{'    '}}{{ field.name }}: {{ field.maybe_optional(field.create_input_type) }}
        {% endif %}
    {% endfor %}


class {{ model.name }}CreateInput({{ model.name }}OptionalCreateInput):
    """Required arguments to the {{ model.name }} create method"""
    {# a field is listed here only when it is required on create and not read-only; #}
    {# the optional keys are inherited from the base class #}
    {% for field in model.all_fields %}
        {%- if field.required_on_create and not field.is_read_only -%}
            {{'    '}}{{ field.name }}: {{ field.maybe_optional(field.create_input_type) }}
        {% endif %}
    {% endfor %}


# TODO: remove these types in favour of inputs that exclude one explicit relation,
# e.g. PostCreateWithoutAuthorInput

class {{ model.name }}OptionalCreateWithoutRelationsInput(TypedDict, total=False):
    """Optional arguments to the {{ model.name }} create method, without relations"""
    {# same selection rule as the OptionalCreateInput type, restricted to scalar fields #}
    {% for field in model.scalar_fields %}
        {%- if not field.required_on_create or field.is_read_only -%}
            {{'    '}}{{ field.name }}: {{ field.maybe_optional(field.create_input_type) }}
        {% endif %}
    {% endfor %}


class {{ model.name }}CreateWithoutRelationsInput({{ model.name }}OptionalCreateWithoutRelationsInput):
    """Required arguments to the {{ model.name }} create method, without relations"""
    {# same selection rule as the CreateInput type, restricted to scalar fields #}
    {% for field in model.scalar_fields %}
        {%- if field.required_on_create and not field.is_read_only -%}
            {{'    '}}{{ field.name }}: {{ field.maybe_optional(field.create_input_type) }}
        {% endif %}
    {% endfor %}

class {{ model.name }}ConnectOrCreateWithoutRelationsInput(TypedDict):
    """Connect-or-create arguments for {{ model.name }}; both keys are required"""
    create: '{{ model.name }}CreateWithoutRelationsInput'
    where: '{{ model.name }}WhereUniqueInput'

class {{ model.name }}CreateNestedWithoutRelationsInput(TypedDict, total=False):
    """Nested create arguments for a single {{ model.name }}; every key is optional"""
    create: '{{ model.name }}CreateWithoutRelationsInput'
    connect: '{{ model.name }}WhereUniqueInput'
    connect_or_create: '{{ model.name }}ConnectOrCreateWithoutRelationsInput'


class {{ model.name }}CreateManyNestedWithoutRelationsInput(TypedDict, total=False):
    """Nested create arguments for {{ model.name }} where every key accepts one item or a list"""
    create: Union['{{ model.name }}CreateWithoutRelationsInput', List['{{ model.name }}CreateWithoutRelationsInput']]
    connect: Union['{{ model.name }}WhereUniqueInput', List['{{ model.name }}WhereUniqueInput']]
    connect_or_create: Union['{{ model.name }}ConnectOrCreateWithoutRelationsInput', List['{{ model.name }}ConnectOrCreateWithoutRelationsInput']]

{# the unique-lookup input is rendered from the type schema entry of this model #}
{{ render_type(model_schema.where_unique) }}

class {{ model.name }}UpdateInput(TypedDict, total=False):
    """Optional arguments for updating a record"""
    {# read-only fields cannot be updated and are left out #}
    {% for field in model.all_fields %}
        {%- if not field.is_read_only -%}
            {{'    '}}{{ field.name }}: {{ field.maybe_optional(field.get_update_input_type()) }}
        {% endif %}
    {% endfor %}


class {{ model.name }}UpdateManyMutationInput(TypedDict, total=False):
    """Arguments for updating many records"""
    {# only scalar fields that are not read-only are listed #}
    {% for field in model.scalar_fields %}
        {%- if not field.is_read_only -%}
            {{'    '}}{{ field.name }}: {{ field.maybe_optional(field.get_update_input_type()) }}
        {% endif %}
    {% endfor %}


class {{ model.name }}UpdateManyWithoutRelationsInput(TypedDict, total=False):
    """Nested update arguments for a list of {{ model.name }} records; every key is optional"""
    create: List['{{ model.name }}CreateWithoutRelationsInput']
    connect: List['{{ model.name }}WhereUniqueInput']
    connect_or_create: List['{{ model.name }}ConnectOrCreateWithoutRelationsInput']
    set: List['{{ model.name }}WhereUniqueInput']
    disconnect: List['{{ model.name }}WhereUniqueInput']
    delete: List['{{ model.name }}WhereUniqueInput']

    # TODO: the following keys are not generated yet
    # update: List['{{ model.name }}UpdateWithWhereUniqueWithoutRelationsInput']
    # updateMany: List['{{ model.name }}UpdateManyWithWhereUniqueWithoutRelationsInput']
    # deleteMany: List['{{ model.name }}ScalarWhereInput']
    # upsert: List['{{ model.name }}UpsertWithWhereUniqueWithoutRelationsInput']


class {{ model.name }}UpdateOneWithoutRelationsInput(TypedDict, total=False):
    """Nested update arguments for a single {{ model.name }} record; every key is optional"""
    create: '{{ model.name }}CreateWithoutRelationsInput'
    connect: '{{ model.name }}WhereUniqueInput'
    connect_or_create: '{{ model.name }}ConnectOrCreateWithoutRelationsInput'
    disconnect: bool
    delete: bool

    # TODO: the following keys are not generated yet
    # update: '{{ model.name }}UpdateInput'
    # upsert: '{{ model.name }}UpsertWithoutRelationsInput'


class {{ model.name }}UpsertInput(TypedDict):
    """Upsert arguments for {{ model.name }}; both `create` and `update` are required"""
    create: '{{ model.name }}CreateInput'
    update: '{{ model.name }}UpdateInput'  # pyright: ignore[reportIncompatibleMethodOverride]


{# the ordering input is rendered from the type schema entry of this model #}
{{ render_type(model_schema.order_by) }}


# recursive {{ model.name }} types
# TODO: cleanup these types


{# where_input_type is the annotation used by the relation filter types that follow #}
{% if recursive_types %}
    {% set where_input_type = model.name + 'WhereInput' %}
{% else %}
# relation filters are typed as Dict[str, Any] because of a mypy limitation
# see https://github.com/RobertCraigie/prisma-client-py/issues/45
# switch to pyright for improved types, see https://prisma-client-py.readthedocs.io/en/stable/reference/limitations/
    {% set where_input_type = 'Dict[str, Any]' %}
{% endif %}

# functional syntax is required here because 'is' is a Python keyword;
# both keys are optional (total=False)
{{ model.name }}RelationFilter = TypedDict(
    '{{ model.name }}RelationFilter',
    {
        'is': '{{ where_input_type }}',
        'is_not': '{{ where_input_type }}',
    },
    total=False,
)


class {{ model.name }}ListRelationFilter(TypedDict, total=False):
    """Filter for a list relation to {{ model.name }}; every key is optional"""
    some: '{{ where_input_type }}'
    none: '{{ where_input_type }}'
    every: '{{ where_input_type }}'


class {{ model.name }}Include(TypedDict, total=False):
    """{{ model.name }} relational arguments

    Each relational field accepts either a bool or the matching
    `...From{{ model.name }}` arguments dictionary.
    """
    {% for field in model.relational_fields -%}
        {{'    '}}{{ field.name }}: Union[bool, '{{ field.relational_args_type }}From{{ model.name }}']
    {% endfor %}


{% for related in dmmf.datamodel.models %}
{# <Related>IncludeFrom<Model>: the relational fields of `related` are only listed #}
{# while recursive() supplies a follow-up type; otherwise the class only has a docstring #}
{% call(name, next, iteration) recursive('%sIncludeFrom%s' % (related.name, model.name)) %}
class {{ name }}(TypedDict, total=False):
    """Relational arguments for {{model.name }}"""
    {% if next != '' -%}
        {% for field in related.relational_fields -%}
            {{'    '}}{{ field.name }}: Union[bool, '{{ field.relational_args_type }}From{{ model.name + iteration }}']
        {% endfor %}
    {% endif %}
{% endcall %}

{# <Related>ArgsFrom<Model>: `include` is only emitted while recursive() supplies a #}
{# follow-up type; note that it refers to the <Related>IncludeFrom<Related> family #}
{% call(name, next, iteration) recursive('%sArgsFrom%s' % (related.name, model.name)) %}
class {{ name }}(TypedDict, total=False):
    """Arguments for {{model.name }}"""
    {%+ if next != '' -%}
        include: '{{ related.name }}IncludeFrom{{ related.name + iteration}}'
    {% endif %}
{% endcall %}

{# FindMany<Related>ArgsFrom<Model>: query arguments typed against `related`; #}
{# `include` is only emitted while recursive() supplies a follow-up type #}
{% call(name, next, iteration) recursive('FindMany%sArgsFrom%s' % (related.name, model.name)) %}
class {{ name }}(TypedDict, total=False):
    """Arguments for {{model.name }}"""
    take: int
    skip: int
    order_by: Union['{{ related.name }}OrderByInput', List['{{ related.name }}OrderByInput']]
    where: '{{ related.name }}WhereInput'
    cursor: '{{ related.name }}WhereUniqueInput'
    distinct: List['{{ related.name }}ScalarFieldKeys']
    {%+ if next != '' -%}
        include: '{{ related.name }}IncludeFrom{{ related.name + iteration}}'
    {% endif %}
{% endcall %}

{% endfor %}


# find-many and find-first share the arguments type generated for the model itself
FindMany{{ model.name }}Args = FindMany{{ model.name }}ArgsFrom{{ model.name }}
FindFirst{{ model.name }}Args = FindMany{{ model.name }}ArgsFrom{{ model.name}}


{# <Model>WhereInput: one key per field, typed with the field's where input type; #}
{# AND / OR / NOT are only emitted while recursive() supplies a follow-up type #}
{% call(current, next, iteration) recursive(model.name + 'WhereInput') %}
class {{ current }}(TypedDict, total=False):
    """{{ model.name }} arguments for searching"""
    {% for field in model.all_fields -%}
        {{'    '}}{{ field.name }}: {{ field.where_input_type }}
    {% endfor %}

    {% if next != '' %}
    # should be noted that AND and NOT should be Union['{{ next }}', List['{{ next }}']]
    # but this causes mypy to hang :/
    AND: List['{{ next }}']
    OR: List['{{ next }}']
    NOT: List['{{ next }}']
    {% endif %}
{% endcall %}


# aggregate {{ model.name }} types


{# <Model>ScalarWhereWithAggregatesInput: one key per scalar field, typed with the #}
{# field's aggregate-aware where input type; AND / OR / NOT are only emitted while #}
{# recursive() supplies a follow-up type #}
{% call(current, next, iteration) recursive(model.name + 'ScalarWhereWithAggregatesInput') %}
class {{ current }}(TypedDict, total=False):
    """{{ model.name }} arguments for searching"""
    {% for field in model.scalar_fields %}
    {{ field.name }}: {{ field.where_aggregates_input_type }}
    {% endfor %}

    {% if next != '' %}
    AND: List['{{ next }}']
    OR: List['{{ next }}']
    NOT: List['{{ next }}']
    {% endif %}
{% endcall %}


class {{ model.name }}GroupByOutput(TypedDict, total=False):
    """Scalar fields of {{ model.name }} plus the aggregate result keys; every key is optional"""
    {% for field in model.scalar_fields -%}
        {{ '    ' }}{{ field.name }}: {{ field.python_type }}
    {% endfor %}
    _sum: '{{ model.name }}SumAggregateOutput'
    _avg: '{{ model.name }}AvgAggregateOutput'
    _min: '{{ model.name }}MinAggregateOutput'
    _max: '{{ model.name }}MaxAggregateOutput'
    _count: '{{ model.name }}CountAggregateOutput'


class {{ model.name }}AvgAggregateOutput(TypedDict, total=False):
    """{{ model.name }} output for aggregating averages

    Only numeric scalar fields are listed and each of them is typed as float.
    """
    {% for field in model.scalar_fields %}
        {% if field.is_number -%}
            {{ '    ' }}{{ field.name }}: float
        {% endif %}
    {% endfor %}


class {{ model.name }}SumAggregateOutput(TypedDict, total=False):
    """{{ model.name }} output for aggregating sums

    Only numeric scalar fields are listed, each typed with the field's own Python type.
    """
    {% for field in model.scalar_fields %}
        {% if field.is_number -%}
            {{ '    ' }}{{ field.name }}: {{ field.python_type }}
        {% endif %}
    {% endfor %}


class {{ model.name }}ScalarAggregateOutput(TypedDict, total=False):
    """{{ model.name }} output including scalar fields

    One optional key per scalar field, typed with the field's Python type.
    """
    {% for field in model.scalar_fields -%}
        {{ '    ' }}{{ field.name }}: {{ field.python_type }}
    {% endfor %}


# min and max aggregates share the scalar output shape
{{ model.name }}MinAggregateOutput = {{ model.name }}ScalarAggregateOutput
{{ model.name }}MaxAggregateOutput = {{ model.name }}ScalarAggregateOutput


class {{ model.name }}MaxAggregateInput(TypedDict, total=False):
    """{{ model.name }} input for aggregating by max

    One optional bool key per scalar field.
    """
    {% for field in model.scalar_fields -%}
        {{ '    ' }}{{ field.name }}: bool
    {% endfor %}


class {{ model.name }}MinAggregateInput(TypedDict, total=False):
    """{{ model.name }} input for aggregating by min

    One optional bool key per scalar field.
    """
    {% for field in model.scalar_fields -%}
        {{ '    ' }}{{ field.name }}: bool
    {% endfor %}


class {{ model.name }}NumberAggregateInput(TypedDict, total=False):
    """{{ model.name }} input for aggregating numbers

    One optional bool key per numeric scalar field.
    """
    {% for field in model.scalar_fields %}
        {% if field.is_number -%}
            {{ '    ' }}{{ field.name }}: bool
        {% endif %}
    {% endfor %}


# avg and sum aggregates share the numeric input shape
{{ model.name }}AvgAggregateInput = {{ model.name }}NumberAggregateInput
{{ model.name }}SumAggregateInput = {{ model.name }}NumberAggregateInput


# count input: one optional bool key per scalar field plus an extra '_all' key
{{ model.name }}CountAggregateInput = TypedDict(
    '{{ model.name }}CountAggregateInput',
    {
        {% for field in model.scalar_fields %}
        '{{ field.name }}': bool,
        {% endfor %}
        '_all': bool,
    },
    total=False,
)

# count output: one optional int key per scalar field plus an extra '_all' key
{{ model.name }}CountAggregateOutput = TypedDict(
    '{{ model.name }}CountAggregateOutput',
    {
        {% for field in model.scalar_fields %}
        '{{ field.name }}': int,
        {% endfor %}
        '_all': int,
    },
    total=False,
)


# literal of every field name on the model (scalar and relational)
{{ model.name }}Keys = Literal[
    {% for field in model.all_fields %}
    '{{ field.name }}',
    {% endfor %}
]
# literal of the scalar field names only
{{ model.name }}ScalarFieldKeys = Literal[
    {% for field in model.scalar_fields %}
    '{{ field.name }}',
    {% endfor %}
]
# type variable bound to the scalar field name literal
{{ model.name }}ScalarFieldKeysT = TypeVar('{{ model.name }}ScalarFieldKeysT', bound={{ model.name }}ScalarFieldKeys)

{# literal of the relational field names; falls back to _NoneType when the model has none #}
{% if model.has_relational_fields -%}
    {{ model.name }}RelationalFieldKeys = Literal[
        {% for field in model.relational_fields %}
        '{{ field.name }}',
        {% endfor %}
    ]
{% else -%}
    {{ model.name }}RelationalFieldKeys = _NoneType
{% endif %}

{% endfor %}


# we have to import ourselves as types can be namespaced to types
from . import types, enums, models, fields
